
[Speculative Decoding] Medusa Implementation with Top-1 proposer #4978

Merged
merged 18 commits into vllm-project:main on Jul 10, 2024

Conversation

abhigoyal1997
Contributor

@abhigoyal1997 abhigoyal1997 commented May 22, 2024

This PR implements the Medusa approach to generate speculations using top-1 predictions of the heads.

For Mistral-7B-Instruct-v0.2 and Meta-Llama-3-8B-Instruct on an H100 card, the following are the throughput numbers (total tokens generated/sec) when tested on MT-Bench; columns are concurrency levels:

| Model | Temperature | 64 | 32 | 16 | 8 | 4 | 2 | 1 |
|---|---|---|---|---|---|---|---|---|
| Mistral-7B-Instruct-v0.2 | 0 | 2425.03 | 1862.19 | 1262.11 | 792.05 | 453.78 | 245.14 | 128.73 |
| Mistral-7B-Instruct-v0.2 | 1 | 1876.1 | 1902.74 | 1281.7 | 799.73 | 452.16 | 245.8 | 126.73 |
| Mistral-7B-Instruct-v0.2 + Medusa | 0 | 1769.62 | 1681.17 | 1318.1 | 911.77 | 550.85 | 293.14 | 155.06 |
| Mistral-7B-Instruct-v0.2 + Medusa | 1 | 1818.85 | 1739.18 | 1414.85 | 911.56 | 554.67 | 297.65 | 164.38 |
| Meta-Llama-3-8B-Instruct | 0 | 2275.64 | 1796.47 | 1258.82 | 760.38 | 429.29 | 235.47 | 121.27 |
| Meta-Llama-3-8B-Instruct | 1 | 2291.75 | 1712.22 | 1246.62 | 740.42 | 420.49 | 236.39 | 121.05 |
| Meta-Llama-3-8B-Instruct + Medusa | 0 | 1533.34 | 1445.23 | 1156.08 | 830.04 | 490.86 | 269.09 | 141.33 |
| Meta-Llama-3-8B-Instruct + Medusa | 1 | 1606.42 | 1457.75 | 1185.55 | 838.49 | 512.31 | 281.95 | 147.08 |
So for smaller batch sizes (lower concurrency), we see an improvement in tokens generated per second.
Medusa heads for these models were trained using a set of public instructions. I am working on making those available via the Hugging Face Hub as well.

With tree-style speculation and verification, this should give even higher improvements.
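For reference, here is a rough usage sketch of how this could be enabled from the `LLM` entrypoint. The Medusa checkpoint name is a placeholder (the trained heads mentioned above are not published yet), and the argument names follow the existing speculative-decoding engine args to the best of my knowledge:

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    # Placeholder path: the Medusa heads trained for this PR are not on the Hub yet.
    speculative_model="<org>/Mistral-7B-Instruct-v0.2-medusa",
    num_speculative_tokens=4,      # one speculative token per Medusa head
    use_v2_block_manager=True,     # spec decode currently requires the v2 block manager
)

outputs = llm.generate(
    ["Explain speculative decoding in one paragraph."],
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```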

FIX #1023
FIX #4669
Fix FasterDecoding/Medusa#41

@abhigoyal1997 abhigoyal1997 mentioned this pull request May 22, 2024
@caddfa31434

@abhigoyal1997 Awesome PR. Do you have any plans regarding support for tree-style speculation and verification?

@abhigoyal1997
Contributor Author

Hi @caddfa31434
Tree-style speculation and verification are certainly something I am interested in. But to get the full benefit, I was still waiting on support for tree attention. I have a PyTorch implementation of it but haven't thought through how it would work in vLLM.

@zhyncs
Contributor

zhyncs commented May 23, 2024

Hi @caddfa31434 Tree-style speculation and verification are certainly something I am interested in. But to get the full benefit, I was still waiting on support for tree attention. I have a PyTorch implementation of it but haven't thought through how it would work in vLLM.

Hi @abhigoyal1997 We recently had the experience of developing a TreeMask-based version of Medusa on our internal inference framework, and it is expected to be ready by the end of the month. Here are some suggestions.

There are three main differences between the TreeMask version and the non-TreeMask version (three modification points):

  1. The Cos/Sin matrix for RoPE
  2. Replacing the causal mask with a tree mask in attention (see the sketch after this list)
  3. Updating KV Cache, which needs to be compacted after each verification
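To make point 2 concrete, here is a small framework-agnostic sketch (not taken from our implementation) of how a tree mask can be built from the candidate tree, where each node may attend only to itself and its ancestors:

```python
import torch

def build_tree_mask(parents: list[int]) -> torch.Tensor:
    """Boolean attention mask for a candidate tree.

    parents[i] is the index of node i's parent, or -1 for a root.
    Node i may attend to node j iff j is i itself or one of its ancestors.
    """
    n = len(parents)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        j = i
        while j != -1:
            mask[i, j] = True
            j = parents[j]
    return mask

# Example: root token 0 with two children (1, 2); node 1 has a child 3.
print(build_tree_mask([-1, 0, 0, 1]).int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 0, 1, 0],
#         [1, 1, 0, 1]])
```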

Additionally, EAGLE is a technique worth considering. I therefore suggest adding some abstractions when implementing Medusa, which will make future integration with EAGLE easier.

@abhigoyal1997
Contributor Author

@zhyncs Thanks for the insights into the tree-mask-based version.
As for EAGLE, I have an implementation ready for that as well. This Medusa implementation keeps the required abstractions in mind, so that any speculative model which needs some inputs from the target model can be implemented with ease. In fact, the EAGLE implementation I have (I'll raise a PR for that soon) only needs a new model added to vLLM, without any change to the speculative-decoding logic itself.

@zhyncs
Contributor

zhyncs commented May 23, 2024

@zhyncs Thanks for the insights into the tree-mask-based version. As for EAGLE, I have an implementation ready for that as well. This Medusa implementation keeps the required abstractions in mind, so that any speculative model which needs some inputs from the target model can be implemented with ease. In fact, the EAGLE implementation I have (I'll raise a PR for that soon) only needs a new model added to vLLM, without any change to the speculative-decoding logic itself.

yep. And in our scenario, EAGLE performs better than Medusa.

@abhigoyal1997
Contributor Author

@cadedaniel This is complete. Can you please review it?

@cadedaniel
Collaborator

Thanks for the contribution! Will take a look today or tomorrow.

@cadedaniel cadedaniel self-requested a review May 23, 2024 17:30
@cadedaniel cadedaniel self-assigned this May 23, 2024
@KexinFeng

Indeed, there are several challenges in tree-style speculative decoding, including but not limited to what is mentioned here. Especially when large batch sizes are considered, different requests may have different acceptance paths in the candidate tree. How to process them efficiently will be an issue. I'm focusing on solving this right now and should be able to upstream my solution to vLLM soon. See also #4669 (comment)

Hi @caddfa31434 Tree-style speculation and verification are certainly something I am interested in. But to get the full benefit, I was still waiting on support for tree attention. I have a PyTorch implementation of it but haven't thought through how it would work in vLLM.

Hi @abhigoyal1997 We recently had the experience of developing a TreeMask-based version of Medusa on our internal inference framework, and it is expected to be ready by the end of the month. Here are some suggestions.

There are three main differences between the TreeMask version and the non-TreeMask version (three modification points):

  1. The Cos/Sin matrix for RoPE
  2. Replacing the causal mask with a tree mask in attention
  3. Updating KV Cache, which needs to be compacted after each verification

Additionally, EAGLE is a technique worth considering. I therefore suggest adding some abstractions when implementing Medusa, which will make future integration with EAGLE easier.

@cadedaniel
Collaborator

+1; I suggest we generalize top-1 and top-k proposing/scoring (including defragmentation of accepted KV). Then we can use the top-1 and top-k implementations with different spec proposal methods (draft, Medusa, EAGLE, ngram, etc.).

Also, we can wait for masking in kernels; we can also implement it in batch expansion style. It won't be as performant, but it could be a faster way to get everything built, since we can add the kernel support when it's ready.
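For illustration, a conceptual sketch of the batch-expansion idea (not vLLM's actual scorer): each sequence's k draft tokens are expanded into k+1 rows of the target batch, one per draft-token prefix, so every speculative position is scored with ordinary decoding and no custom attention mask.

```python
def expand_for_scoring(prompt_ids: list[int], draft_ids: list[int]) -> list[list[int]]:
    """Expand one sequence with k draft tokens into k+1 scoring rows
    (prefix only, prefix + 1 draft token, ..., prefix + k draft tokens)."""
    return [prompt_ids + draft_ids[:i] for i in range(len(draft_ids) + 1)]

# One request with 3 draft tokens becomes 4 rows in the expanded target batch.
for row in expand_for_scoring([11, 12, 13], [101, 102, 103]):
    print(row)
```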

cc @LiuXiaoxuanPKU

@zhyncs
Contributor

zhyncs commented May 24, 2024

we can also implement it in batch expansion style

The performance will be particularly poor. I don't recommend doing this.

@zhyncs
Contributor

zhyncs commented May 24, 2024

we can also implement it in batch expansion style

The performance will be particularly poor. I don't recommend doing this.

At the same time, supporting TreeMask in the attention kernel is not as much work as imagined. The prerequisite is to understand the implementation of the original causal mask. If you're interested, we can discuss the details further.

@zhyncs
Contributor

zhyncs commented May 24, 2024

Especially when large batch sizes are considered, different requests may have different acceptance paths in the candidate tree.

From our internal experience, the real challenge is not integrating with continuous batching, but rather compatibility with existing features such as chunked prefill.

@KexinFeng

KexinFeng commented May 24, 2024

we can also implement it in batch expansion style

The performance will be particularly poor. I don't recommend doing this.

At the same time, supporting TreeMask in the attention kernel is not as much work as imagined. The prerequisite is to understand the implementation of the original causal mask. If you're interested, we can discuss the details further.

About the tree mask in the attention kernel, we can indeed combine efforts. Take a look at the discussion here: Dao-AILab/flash-attention#924. The flash-attention repo is also calling for such a contribution. The API design can refer to Hugging Face's implementation: huggingface/transformers#27539. This efficient tree attention kernel will be a crucial factor in tree-style speculative decoding, which we should consider prioritizing.

@KexinFeng

Especially when large batch sizes are considered, different requests may have different acceptance paths in the candidate tree.

From our internal experience, the real challenge is not integrating with continuous batching, but rather compatibility with existing features such as chunked prefill.

Right. Generally speaking, speculative decoding (not necessarily the tree-style variant) and chunked prefill both try to use compute that would otherwise sit idle while decoding is communication/memory bound, so they compete with each other. But there is a slight difference in their application scenarios: chunked prefill is applied to long prompt inputs, while speculative decoding focuses on accelerating small batch sizes. The incompatibility mentioned here is more of a trade-off between these two scenarios and how to allocate that spare compute budget.

@zhyncs
Contributor

zhyncs commented May 24, 2024

This efficient tree attention kernel will be a crucial factor in tree-style speculative decoding, which we should consider prioritizing.

In fact, our implementation is not based on Dao-AILab/flash-attention but on the TurboMind 2.1 attention kernel, which was written from scratch and performs about 10% better than Dao-AILab/flash-attention.

The changes related to the causal mask are roughly as follows:

// original: plain causal mask (a query position may not attend to later key positions)
    __device__ void ApplyCasualMask(FragS& frag_S, int offset_Q, int offset_K)
    {
        Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float& score) {
            if (offset_Q + qi < offset_K + si) {
                score -= std::numeric_limits<float>::infinity();
            }
        });
    }
// modified: tree (Medusa) mask inside the speculative segment, causal mask elsewhere
    __device__ void ApplyCasualMask(
        FragS& frag_S, int offset_Q, int offset_K, const int* medusa_mask, int his_len, int input_len, int query_idx)
    {
        Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float& score) {
            if (medusa_mask) {
                // Positions relative to the start of the current (speculative) input
                // segment, which follows his_len cached history tokens.
                int rel_pos_q = offset_Q + qi - his_len;
                int rel_pos_k = offset_K + si - his_len;
                if (0 <= rel_pos_q && rel_pos_q < input_len && 0 <= rel_pos_k && rel_pos_k < input_len) {
                    // Both positions are in the speculative segment: mask out pairs the
                    // candidate tree does not allow (zero entries in medusa_mask).
                    if (medusa_mask[rel_pos_q * input_len + rel_pos_k] == 0) {
                        score -= std::numeric_limits<float>::infinity();
                    }
                }
                else {
                    // Outside the speculative segment: ordinary causal masking.
                    if (offset_Q + qi < offset_K + si) {
                        score -= std::numeric_limits<float>::infinity();
                    }
                }
            }
            else {
                // No tree mask supplied: identical to the original causal mask.
                if (offset_Q + qi < offset_K + si) {
                    score -= std::numeric_limits<float>::infinity();
                }
            }
        });
    }

@zhyncs
Contributor

zhyncs commented May 24, 2024

The incompatibility mentioned here is more of a trade-off

This depends on the implementation of each framework, e.g. whether the existing framework has a sufficiently good abstract design and whether it is convenient to extend, so it cannot be generalized. In comparison, vLLM is indeed friendlier for secondary development and also has relatively strong extensibility.

@cadedaniel
Collaborator

If you're interested in combining chunked prefill and spec decode, see #5016. We have a naive dynamic speculation length policy which disables spec decode when the batch size gets too large.
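For illustration, a toy version of such a policy might look like the following (the function name and threshold are made up; the actual policy lives in #5016):

```python
def num_speculative_tokens_for_step(batch_size: int,
                                    max_speculative_tokens: int = 4,
                                    disable_above_batch_size: int = 32) -> int:
    """Naive dynamic speculation-length policy: speculate only while the batch
    is small enough that decoding is memory-bound; otherwise disable spec
    decode for this step and fall back to plain autoregressive decoding."""
    if batch_size > disable_above_batch_size:
        return 0
    return max_speculative_tokens
```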

@abhigoyal1997
Contributor Author

Hi @cadedaniel
Did you get a chance to look at this?

@zhyncs
Contributor

zhyncs commented May 28, 2024

also, we can wait for masking in kernels

@cadedaniel FlashInfer now supports custom masks: flashinfer-ai/flashinfer#266

@zhyncs
Contributor

zhyncs commented May 28, 2024

Hi @cadedaniel Did you get a chance to look at this?

@abhigoyal1997 Could you resolve the conflicts first? Thanks.

@abhigoyal1997
Contributor Author

@zhyncs Thanks for pointing that out. I've resolved the conflicts 👍

Resolved review threads on: vllm/sequence.py, vllm/transformers_utils/config.py, vllm/worker/embedding_model_runner.py, vllm/worker/model_runner.py
@caddfa31434

@abhigoyal1997 It seems that there are some issues with TP > 1.

@abhigoyal1997
Contributor Author

abhigoyal1997 commented May 28, 2024

@caddfa31434 Thanks for testing and catching this. The problem was with the order of execution and hidden_states broadcasting in non-driver workers. The latest commit should fix these issues (I have tested for TP = 2).

Collaborator

@cadedaniel cadedaniel left a comment

Thanks for the contribution -- glad to finally see Medusa working in open-source vLLM. Adding high-level feedback. Some other questions:

  • Can we add an e2e test with Medusa? We should expect greedy generation with Medusa (temp=0) to be equal to the non-spec-decode case (a rough sketch follows after this list). You can follow this as an example.
    def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
    • We'll also want to cover cases like cuda graph / tp>1 / other non-greedy sampling params.
  • The biggest concern I have with this PR is the modification of prepare inputs for Medusa-specific models. It seems this PR is introducing two new things to prepare inputs -- allow on-GPU inputs, and also model-specific input config. Can we separate out these changes into their own PR to make things simpler? Additionally, can you walk me through the alternatives to model-specific input config in prepare inputs?
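A rough sketch of the kind of greedy-equivalence test being asked for above; the model and checkpoint names are placeholders, and a real test should reuse the existing spec-decode e2e fixtures rather than constructing two engines directly like this:

```python
import pytest
from vllm import LLM, SamplingParams

PROMPTS = ["The capital of France is", "Speculative decoding works by"]

@pytest.mark.parametrize("tp_size", [1])  # extend to tp>1 / cuda graph later
def test_medusa_greedy_matches_baseline(tp_size):
    greedy = SamplingParams(temperature=0.0, max_tokens=32)

    baseline = LLM(model="JackFram/llama-68m", tensor_parallel_size=tp_size)
    base_out = [o.outputs[0].text for o in baseline.generate(PROMPTS, greedy)]
    del baseline  # free GPU memory before starting the speculative engine

    spec = LLM(
        model="JackFram/llama-68m",
        speculative_model="<medusa-heads-checkpoint>",  # placeholder
        num_speculative_tokens=4,
        use_v2_block_manager=True,
        tensor_parallel_size=tp_size,
    )
    spec_out = [o.outputs[0].text for o in spec.generate(PROMPTS, greedy)]

    assert base_out == spec_out
```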

vllm/config.py Outdated
@@ -137,6 +139,13 @@ def __init__(
sliding_window_len=self.get_hf_config_sliding_window())
self.served_model_name = get_served_model_name(model,
served_model_name)

self.extra_inputs: Dict[str, Tuple[Tuple[int],
Collaborator

Can we list out the schema of what's allowed here for Medusa?

vllm/config.py Outdated
@@ -321,6 +331,10 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size

def set_num_lookahead_tokens(self, num_lookahead_tokens: int):
Collaborator

Can you describe the relationship between num_lookahead_tokens and num medusa heads?

For context, the num_lookahead_tokens value is used to allocate KV space for speculative tokens. Since Medusa does not use KV, we shouldn't require this to be equal to num heads.

Resolved review threads on: vllm/engine/arg_utils.py, vllm/engine/llm_engine.py
Comment on lines +73 to +93
logits = torch.stack(logits, dim=0).float()
logprobs = torch.log_softmax(logits, dim=-1)
token_ids = logits.argmax(-1) # support only top-1 for now
probs = torch.softmax(logits, dim=-1)
Collaborator

If we use the lossless rejection sampler, we will have to run vLLM's standard sampling routine here -- the probability distribution must be modified in the same way as the scoring probability distributions, else you will get distributional drift in the output.

Contributor Author

Can you please elaborate on the distribution shift? The tokens from the draft model are either accepted or rejected based on the target model distribution, right? So even if the tokens from the draft come from a slightly different distribution, the final output should still match the target model distribution due to rejection. Is this understanding wrong, or am I missing something?

The issue with using the standard sampling routine is that it was causing too much overhead, so if we do need to use it, we might need some optimizations there to get a speed-up out of Medusa.
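For context, a minimal sketch of the standard lossless acceptance rule being discussed (my paraphrase, not vLLM's rejection sampler). The output token follows the target distribution p for any proposal distribution q, but only if q is the distribution the draft token was actually sampled from; that is why the proposer's probabilities must go through the same sampling modifications as the scorer's:

```python
import torch

def accept_or_resample(draft_token: int, q: torch.Tensor, p: torch.Tensor) -> int:
    """q: proposal distribution the draft token was sampled from (vocab-sized).
    p: target-model distribution for the same position (vocab-sized).
    Returns a token distributed according to p (lossless for any valid q)."""
    accept_prob = torch.clamp(p[draft_token] / q[draft_token], max=1.0)
    if torch.rand(()) < accept_prob:
        return draft_token
    # On rejection, resample from the normalized residual distribution (p - q)+.
    residual = torch.clamp(p - q, min=0.0)
    residual = residual / residual.sum()  # assumes p != q so the residual is nonzero
    return int(torch.multinomial(residual, num_samples=1))
```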

Contributor Author

@abhigoyal1997 abhigoyal1997 Jun 5, 2024

There's one case I have noticed that sometimes generates different tokens (not sure if this is what you are referring to, though).
If, without Medusa, the logits of the top-2 tokens have very close (or identical) values, then with Medusa those values sometimes change slightly (I don't know why this happens, since Medusa shouldn't affect the output logits of the target model). This causes different tokens to be preferred by the target model, even with greedy sampling, depending on how those values change.

Two attached screenshots (2024-06-05) show this.

Contributor Author

I realised this was happening because of bf16 precision; I don't see any such shift when using fp32.

Resolved review threads on: vllm/sequence.py, vllm/spec_decode/multi_head_worker.py
@abhigoyal1997 abhigoyal1997 changed the title [Speculative Decoding] Medusa Implementation [Speculative Decoding] Medusa Implementation with Top-1 proposer Jul 1, 2024
@abhigoyal1997
Contributor Author

@LiuXiaoxuanPKU Can you please take a look?

Collaborator

@LiuXiaoxuanPKU LiuXiaoxuanPKU left a comment

Sorry for the late review! LGTM! But I have some minor questions & comments:

  1. It seems only greedy sampling is supported. Standard sampling is not supported.
  2. Could you add some end-to-end tests to make sure Medusa generates almost the same results as without speculative decoding? It doesn't need to be very strict. Take a look at this.

Happy to get this pr merged soon, sorry for the delay!

@abhigoyal1997
Contributor Author

abhigoyal1997 commented Jul 9, 2024

Thanks @LiuXiaoxuanPKU for the review!

Sorry for the late review! LGTM! But I have some minor questions & comments:

  1. It seems only greedy sampling is supported. Standard sampling is not supported.

Standard sampling works as well; I have results showing a speed-up with temperature=1 in the PR description. Since this PR only adds the top-1 proposal candidate, sampling in the Medusa model just uses argmax (we can extend this to top-k when we add tree speculation). My thinking was that even when sampling with a temperature of 1, the target model is more likely to choose the top-1 token because it still has the highest probability, so proposing the top-1 token from Medusa gives it a better chance of being accepted than any other single token (the chance that tokens sampled independently from the target and the Medusa head would match is much lower). Am I missing something in this argument? Even in the Medusa paper, the candidates are formed only from top-k tokens.

  1. Could you add some end-to-end tests to make sure Medusa generates almost the same results as without speculative decoding? It doesn't need to be very strict. Take a look at this.

Thanks for the reference! I've added a similar test for Medusa here: https://github.com/flipkart-incubator/vllm/blob/medusa/tests/spec_decode/e2e/test_medusa_correctness.py
Currently it uses random weights for the Medusa heads, along with JackFram/llama-68m as the base model, since with the actual ones (using Vicuna-1.3) I was getting OOM in CI.

Happy to get this pr merged soon, sorry for the delay!

@LiuXiaoxuanPKU LiuXiaoxuanPKU merged commit 2416b26 into vllm-project:main Jul 10, 2024
70 checks passed
adityagoel14 pushed a commit to adityagoel14/vllm-torchrun-test that referenced this pull request Jul 10, 2024
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request Jul 17, 2024
@hustxiayang

hustxiayang commented Jul 25, 2024

Hi, I assume that you got Mistral-7B-Instruct-v0.2 + Medusa from https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa, and that for Meta-Llama-3-8B-Instruct + Medusa you trained Medusa-1 yourself using the codebase from https://github.com/FasterDecoding/Medusa/tree/main. How many Medusa heads did you use for training?

@hustxiayang

Hi, it seems that the speedups here are much smaller compared to those of the MLPSpeculator: #4947. Do you have any insights on when we should use Medusa instead of the MLPSpeculator?

@abhigoyal1997
Contributor Author

Hi, I assume that you got Mistral-7B-Instruct-v0.2 + Medusa from https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa, and that for Meta-Llama-3-8B-Instruct + Medusa you trained Medusa-1 yourself using the codebase from https://github.com/FasterDecoding/Medusa/tree/main. How many Medusa heads did you use for training?

Hi @hustxiayang
Actually I trained my own Medusa heads for both Mistral and Llama3 (will soon release them on HF as well). As for the number of heads, I used 4.

Hi, it seems that the speedups here are much smaller compared to those of the MLPSpeculator: #4947. Do you have any insights on when we should use Medusa instead of the MLPSpeculator?

I am not sure whether there are specific guidelines on when one should be used over the other. As for performance, it depends on the dataset used, the training method, etc.; the results I posted in the description are for a baseline model I trained, which can be improved with more work. Tree speculation and verification using MQA should improve performance as well.

@hustxiayang

Hi, it seems that this implementation is based on Medusa-1, which loads a separate lm_head for each Medusa head. Medusa-2 proposes reusing the lm_head from the base model. Have you already investigated this, and do you plan to implement it? Thanks!
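For context, a minimal sketch of what one Medusa head looks like, based on the Medusa paper's residual-block design (not the exact code in this PR). In Medusa-1 each head owns its own lm_head as below, whereas Medusa-2 would reuse the base model's lm_head for the final projection:

```python
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """One Medusa head: a residual MLP block on top of the base model's last
    hidden state, followed by a vocabulary projection for the k-th future token."""
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()
        # Medusa-1: per-head lm_head. Medusa-2 would reuse the base model's lm_head here.
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        h = hidden_states + self.act(self.linear(hidden_states))  # residual block
        return self.lm_head(h)                                    # logits over vocab
```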

Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024